/*
SMillepede.cxx

Implementation of the object-oriented wrapper around the Fortran Millepede code, and of
the steering of the whole alignment process

Author: Kun Liu, liuk@fnal.gov
Created: 05-01-2013
*/

#include <iostream>
#include <cmath>
#include <vector>
#include <list>
#include <algorithm>
#include <fstream>
#include <sstream>

#include <TGraphErrors.h>
#include <TCanvas.h>
#include <TMath.h>
#include <TAxis.h>
#include <TF1.h>

#include "KalmanUtil.h"
#include "SMillepede.h"

/*
Default constructor.

The evaluation output (file, tree, node buffer, residual histograms) is only
created by bookEvaluationTree(), so every pointer is set to NULL here. The
destructor deletes evalNode, so it must never be left indeterminate.
*/
SMillepede::SMillepede()
{
  evalFile = NULL;
  evalTree = NULL;
  evalNode = NULL;
  for(int i = 0; i < MILLEPEDE::NPLAN; i++) evalHist[i] = NULL;

  rawEvent = NULL;
  recEvent = NULL;

  //runID/eventID are attached to the evaluation tree but their filling is
  //currently disabled in fillEvaluationTree(); give them a defined value
  runID = 0;
  eventID = 0;

  //Counters are reset again by init(); initialized here so they are never garbage
  nTracks = 0;
  for(int i = 0; i < MILLEPEDE::NPLAN; i++) nHits[i] = 0;

  p_geomSvc = GeomSvc::instance();
}

/*
Destructor.

If bookEvaluationTree() was called, deletes the residual histograms, writes the
evaluation tree and closes the file. Then it releases the file and the MPNode
branch buffer. Nothing is deleted when no evaluation output was booked.
*/
SMillepede::~SMillepede()
{
  if(evalFile != NULL)
    {
      for(int i = 0; i < MILLEPEDE::NPLAN; i++) delete evalHist[i];

      evalFile->cd();
      evalTree->Write();
      //TFile::Close() deletes the objects held by the file directory (evalTree)
      evalFile->Close();
      delete evalFile;
      evalFile = NULL;

      //evalNode is allocated only in bookEvaluationTree(), together with evalFile
      delete evalNode;
      evalNode = NULL;
    }
}

/*
Prepare a new alignment job: reset the track and per-plane hit counters, zero
every global (alignment) parameter and its error, then configure millepede via
initMillepede().
*/
void SMillepede::init()
{
  //Bookkeeping counters
  nTracks = 0;
  for(int plane = 0; plane < MILLEPEDE::NPLAN; ++plane) nHits[plane] = 0;

  //Global parameters and their errors start from zero
  for(int iPar = 0; iPar < MILLEPEDE::NGLB; ++iPar)
    {
      par_align[iPar] = 0.;
      err_align[iPar] = 0.;
    }

  //Pass dimensions, initial errors and constraints to millepede
  initMillepede();
}

/*
Quality selection for a reconstructed (Kalman-fitted) track.
Accepted when it has at least NHITSMIN hits, a chi-square not above CHISQMAX,
and a station-1 momentum not below MOMMIN. The cuts are evaluated in that
order and short-circuit.
(A hodoscope-mask requirement, recTrack.isHodoMasked(), is currently disabled.)
*/
bool SMillepede::acceptTrack(SRecTrack& recTrack)
{
  return !(recTrack.getNHits() < MILLEPEDE::NHITSMIN)
      && !(recTrack.getChisq() > MILLEPEDE::CHISQMAX)
      && !(recTrack.getMomentumSt1() < MILLEPEDE::MOMMIN);
}

/*
Quality selection for a Tracklet.
Accepted when it has at least NHITSMIN hits, a fit probability not below
PROBMIN, and a momentum (1/invP) not below MOMMIN. The cuts are evaluated in
that order and short-circuit.
(A chi-square cut against CHISQMAX is currently disabled for tracklets.)
*/
bool SMillepede::acceptTrack(Tracklet& track)
{
  return !(track.getNHits() < MILLEPEDE::NHITSMIN)
      && !(track.getProb() < MILLEPEDE::PROBMIN)
      && !(1./track.invP < MILLEPEDE::MOMMIN);
}

/*
Register a raw/reconstructed event pair and feed every accepted reconstructed
track to millepede. Both pointers are stored, because addTrack(SRecTrack&) looks
up hits through them. Each accepted track increments nTracks.
*/
void SMillepede::setEvent(SRawEvent* rawEvt, SRecEvent* recEvt)
{
  rawEvent = rawEvt;
  recEvent = recEvt;
  rawEvent->reIndex("oah");

  const int nCandidates = recEvent->getNTracks();
  for(int iTrack = 0; iTrack < nCandidates; ++iTrack)
    {
      SRecTrack candidate = recEvent->getTrack(iTrack);
      if(acceptTrack(candidate))
        {
          addTrack(candidate);
          ++nTracks;
        }
    }
}

/*
Feed every accepted Tracklet stored in a TClonesArray to millepede.
Each accepted tracklet increments nTracks.
*/
void SMillepede::setEvent(TClonesArray* trks)
{
  const int nCandidates = trks->GetEntries();
  for(int iTrack = 0; iTrack < nCandidates; ++iTrack)
    {
      Tracklet& candidate = *static_cast<Tracklet*>(trks->At(iTrack));
      if(acceptTrack(candidate))
        {
          addTrack(candidate);
          ++nTracks;
        }
    }
}

/*
Convert one Tracklet into the millepede node list and run the local fit.

Hits with hit.index >= 0 become real measurement nodes and increment nHits of
their plane. Hits with a negative index become dummy MPNodes.
Six further dummy nodes are appended: detectors 13-18 if the last hit's detector
ID is above 18, otherwise detectors 19-24. Presumably these cover the station-3
half the track did not cross — TODO confirm.
The node list is sorted, passed to setSingleTrack(), and then to
fillEvaluationTree(), which is a no-op unless evaluation output was booked.
A tracklet with no hits is ignored (nodes are cleared, nothing is fitted).
*/
void SMillepede::addTrack(Tracklet& trk)
{
  nodes.clear();

  //trk.hits.back() below is undefined on an empty list, and there is nothing to fit
  if(trk.hits.empty()) return;

  //Push meaningful nodes
  for(std::list<SignedHit>::iterator iter = trk.hits.begin(); iter != trk.hits.end(); ++iter)
    {
      if(iter->hit.index < 0)
        {
          MPNode node_dummy(iter->hit.detectorID);
          nodes.push_back(node_dummy);
        }
      else
        {
          MPNode node_real(*iter, trk);
          nodes.push_back(node_real);

          ++nHits[iter->hit.detectorID - 1];
        }
    }

  //Push dummy nodes for the other station-3 half (13-18 or 19-24)
  int detectorID_s3 = trk.hits.back().hit.detectorID > 18 ? 13 : 19;
  for(int i = 0; i < 6; i++)
    {
      MPNode node_dummy(detectorID_s3 + i);
      nodes.push_back(node_dummy);
    }
  std::sort(nodes.begin(), nodes.end());

  setSingleTrack();
  fillEvaluationTree();
}

/*
Convert one reconstructed (Kalman) track into the millepede node list and run
the local fit.

Every hit of the track is looked up in the current raw event. The sign of the
hit index is applied to the drift distance (negative index -> negated drift
distance), and the hit is wrapped in a Kalman Node carrying the track's z,
smoothed state and covariance at that hit. Each such node increments nHits of
its plane.
Every detector plane 1..NPLAN without a hit gets a dummy MPNode, so the node
list covers all planes. The list is sorted, passed to setSingleTrack(), and then
to fillEvaluationTree() (a no-op unless evaluation output was booked).
Requires setEvent(SRawEvent*, SRecEvent*) to have set rawEvent/recEvent.
*/
void SMillepede::addTrack(SRecTrack& trk)
{
  //All detector plane IDs 1..NPLAN, in increasing order (set_difference needs sorted input)
  std::vector<int> detectorIDs_all;
  for(int i = 1; i <= MILLEPEDE::NPLAN; i++) detectorIDs_all.push_back(i);

  //Prepare the node list with real hits
  std::vector<int> detectorIDs;
  nodes.clear();
  const int nHitsTrk = trk.getNHits();
  for(int i = 0; i < nHitsTrk; i++)
    {
      int hitIndex = trk.getHitIndex(i);
      Hit h = rawEvent->getHit(recEvent->getLocalID(abs(hitIndex)));
      //The sign of the hit index carries the drift-side sign; never divide by |hitIndex|
      if(hitIndex < 0) h.driftDistance = -h.driftDistance;

      Node node_kalman(h);
      node_kalman.setZ(trk.getZ(i));
      node_kalman.getSmoothed()._state_kf = trk.getStateVector(i);
      node_kalman.getSmoothed()._covar_kf = trk.getCovariance(i);

      MPNode node_mp(node_kalman);
      nodes.push_back(node_mp);

      detectorIDs.push_back(h.detectorID);
      ++nHits[h.detectorID - 1];
    }
  std::sort(detectorIDs.begin(), detectorIDs.end());

  //Insert dummy MPNodes for the detectors without hits
  //(the difference can never be larger than the full plane list)
  std::vector<int> detectorIDs_miss(detectorIDs_all.size());
  std::vector<int>::iterator iter = std::set_difference(detectorIDs_all.begin(), detectorIDs_all.end(), detectorIDs.begin(), detectorIDs.end(), detectorIDs_miss.begin());
  detectorIDs_miss.resize(iter - detectorIDs_miss.begin());
  for(unsigned int i = 0; i < detectorIDs_miss.size(); i++)
    {
      MPNode node_mp(detectorIDs_miss[i]);
      nodes.push_back(node_mp);
    }
  std::sort(nodes.begin(), nodes.end());

  setSingleTrack();
  fillEvaluationTree();
}

/*
Add a millepede linear constraint forcing parameter paraID of detectorID1 and
detectorID2 to be equal: par(detectorID1) - par(detectorID2) = 0.
Global index layout: (detectorID - 1)*NPARPLAN + paraID.
*/
void SMillepede::constrainDetectors(int detectorID1, int detectorID2, int paraID)
{
  float dercs[MILLEPEDE::NGLB];
  std::fill(dercs, dercs + MILLEPEDE::NGLB, 0.f);

  dercs[(detectorID1 - 1)*MILLEPEDE::NPARPLAN + paraID] = 1.;
  dercs[(detectorID2 - 1)*MILLEPEDE::NPARPLAN + paraID] = -1.;

  float rhs = 0.;
  constf_(dercs, &rhs);
}

/*
Configure the Fortran millepede package for this alignment job.

Call sequence: initgl_ (array dimensions and print level), parglo_/parsig_
(initial parameter values from par_align and per-parameter sigmas from
err_align), several constf_ linear constraints, and finally initun_.
NOTE(review): the call order presumably matters to millepede's internal state —
keep it unchanged.

Global parameter layout (see setSingleTrack): index = (detectorID-1)*NPARPLAN + p,
with p = 0 -> z offset, 1 -> phi (rotation), 2 -> w (in-plane) offset.
*/
void SMillepede::initMillepede()
{
  //Define the dimensions of the arrays in millepede
  int nGlobal = MILLEPEDE::NGLB;
  int nLocal = MILLEPEDE::NPARTRK;
  int nStdev = MILLEPEDE::NSTDEV;
  int iPrint = 1;

  initgl_(&nGlobal, &nLocal, &nStdev, &iPrint);

  //Set the initial sigma of each parameter of every plane (detector IDs 1..NPLAN):
  //z and phi get 0 (presumably fixed in millepede — TODO confirm), w gets 0.1
  for(int i = 1; i <= MILLEPEDE::NPLAN; i++)
    {
      setDetectorParError(i, 0, 0.0);
      setDetectorParError(i, 1, 0.0);
      setDetectorParError(i, 2, 0.1);

      //if(i <= 6) setDetectorParError(i, 0, 2);
    }

  //Optional fixes of individual parameters (currently disabled)
  //fixDetectorParameter(p_geomSvc->getDetectorID("D2V"), 1);
  //fixDetectorParameter(p_geomSvc->getDetectorID("D2Vp"), 1);
  //fixDetectorParameter(p_geomSvc->getDetectorID("D3pX"), 0);
  //fixDetectorParameter(p_geomSvc->getDetectorID("D3pXp"), 0);
  //fixDetectorParameter(p_geomSvc->getDetectorID("D1U"), 2);
  //fixDetectorParameter(p_geomSvc->getDetectorID("D3mVp"), 2);
  
  //Now pass the initial values and sigmas above to millepede (parsig_ is 1-based)
  parglo_(par_align);
  for(int i = 0; i < MILLEPEDE::NGLB; i++)
    {
      int index = i+1;
      parsig_(&index, &err_align[i]);
    }

  //Constraints
  //1. Force each pair of adjacent planes (1,2), (3,4), ... to share z and phi
  for(int i = 1; i <= MILLEPEDE::NPLAN; i += 2)
    {
      constrainDetectors(i, i+1, 0);
      constrainDetectors(i, i+1, 1);
      //constrainDetectors(i, i+1, 2);
    }

  //2. Chain pairs (14,15) and (16,17); together with step 1 this ties all 6
  //   planes 13-18 (station 3+) to the same z and phi
  for(int i = 14; i < 17; i += 2)
    {
      constrainDetectors(i, i+1, 0);
      constrainDetectors(i, i+1, 1);
    }

  //3. A special one: a signed sum of the w offsets of detectors 13-18
  //   (signs +,+,-,-,+,+) is constrained to zero, fixing a global w offset of station 3+
  float rhs = 0.;
  float dercs[MILLEPEDE::NGLB];

  for(int k = 0; k < MILLEPEDE::NGLB; k++) dercs[k] = 0.;
  dercs[MILLEPEDE::NPARPLAN*12 + 2] = 1.;
  dercs[MILLEPEDE::NPARPLAN*13 + 2] = 1.;
  dercs[MILLEPEDE::NPARPLAN*14 + 2] = -1.;
  dercs[MILLEPEDE::NPARPLAN*15 + 2] = -1.;
  dercs[MILLEPEDE::NPARPLAN*16 + 2] = 1.;
  dercs[MILLEPEDE::NPARPLAN*17 + 2] = 1.;

  constf_(dercs, &rhs);

  //4. Fix the global positions: the sum of z offsets over all planes is zero ...
  for(int k = 0; k < MILLEPEDE::NGLB; k++) dercs[k] = 0.;
  for(int i = 0; i < MILLEPEDE::NPLAN; i++) dercs[i*MILLEPEDE::NPARPLAN + 0] = 1.;
  constf_(dercs, &rhs); 
 
  //   ... and so is the sum of phi offsets over all planes
  for(int k = 0; k < MILLEPEDE::NGLB; k++) dercs[k] = 0.;
  for(int i = 0; i < MILLEPEDE::NPLAN; i++) dercs[i*MILLEPEDE::NPARPLAN + 1] = 1.;
  constf_(dercs, &rhs); 
  
  //   (the same constraint on the sum of w offsets is currently disabled)
  //for(int k = 0; k < MILLEPEDE::NGLB; k++) dercs[k] = 0.;
  //for(int i = 0; i < MILLEPEDE::NPLAN; i++) dercs[i*MILLEPEDE::NPARPLAN + 2] = 1.;
  //constf_(dercs, &rhs); 

  //Initialize the iteration setting
  //NOTE(review): arguments presumably are a Fortran unit number (11) and a cut factor — confirm with millepede docs
  int iUnit = 11;
  float cFactor = 1000.;
  initun_(&iUnit, &cFactor);
}

/*
Fix one alignment parameter of one plane: its initial error in err_align is set
to zero, and a millepede constraint par = 0 is added for it.
Note that err_align only reaches millepede through the parsig_ calls in
initMillepede(), so the error part is effective only if this runs before them.
*/
void SMillepede::fixDetectorParameter(int detectorID, int parameterID)
{
  const int globalIndex = MILLEPEDE::NPARPLAN*(detectorID - 1) + parameterID;

  err_align[globalIndex] = 0.;

  float dercs[MILLEPEDE::NGLB];
  for(int k = 0; k < MILLEPEDE::NGLB; ++k) dercs[k] = 0.;
  dercs[globalIndex] = 1.;

  float rhs = 0.;
  constf_(dercs, &rhs);
}

void SMillepede::setSingleTrack()
{
  using namespace MILLEPEDE;

  //Global and local derivatives
  float dergb[NGLB];
  float derlc[NPARTRK];
  float meas;
  float sigma;

  //Fill the nodes to derivative arrays
  for(std::vector<MPNode>::iterator node = nodes.begin(); node != nodes.end(); ++node)
    {
      if(!node->isValid()) continue;
   
      //Initialization of all parameters
      zerloc_(dergb, derlc);

      //Get measurements
      int index = node->detectorID - 1;
      meas = node->meas;
      sigma = node->sigma;

      //Fill local derivarives
      derlc[0] = node->dwdx;
      derlc[1] = node->dwdy;
      derlc[2] = node->dwdtx;
      derlc[3] = node->dwdty;

      //Fill global derivatives
      dergb[NPARPLAN*index + 0] = node->dwdz;
      dergb[NPARPLAN*index + 1] = node->dwdphi;
      dergb[NPARPLAN*index + 2] = node->dwdw;

      //Book the local/global derivatives, measurement and error
      equloc_(dergb, derlc, &meas, &sigma);
    }

  /*
  //Add virtual node to contrain derlc at z = 40 cm (dump face)
  zerloc_(dergb, derlc);
  meas = 0.;
  sigma = 5.;

  derlc[1] = 1.;
  derlc[3] = 40.;

  equloc_(dergb, derlc, &meas, &sigma);

  //Add virtual node to contrain derlc at z = 275 cm (FMAG bend plane)
  zerloc_(dergb, derlc);
  meas = 0.;
  sigma = 5.;

  derlc[0] = 1.;
  derlc[2] = 275.;

  equloc_(dergb, derlc, &meas, &sigma);
  */

  //Perform local track fit
  fitloc_();
}

/*
Run the global millepede fit. The fitted parameters are written into par_align,
and the error of every global parameter is retrieved into err_align.
*/
void SMillepede::fitAlignment()
{
  //Perform global fits
  fitglo_(par_align);

  //Retrieve the errors of the aligned parameters.
  //The Fortran interface numbers global parameters from 1, consistent with the
  //parsig_ calls in initMillepede(); passing the 0-based i would shift every error by one.
  for(int i = 0; i < MILLEPEDE::NGLB; i++)
    {
      int index = i + 1;
      err_align[i] = errpar_(&index);
    }
}

void SMillepede::printResults(std::string outputFileName, std::string increamentFileName)
{
  using namespace MILLEPEDE;
  using namespace std;
  cout << nTracks << " tracks are used in the alignment. " << endl;

  fstream fout1, fout2;
  fout1.open(outputFileName.c_str(), ios::out);
  fout2.open(increamentFileName.c_str(), ios::out);
  for(int i = 0; i < NPLAN; i++)
    {
      cout << i+1 << "   " << p_geomSvc->getDetectorName(i+1) << "     " << nHits[i]
	   << "     " << par_align[i*NPARPLAN + 0] << " +/- " << err_align[i*NPARPLAN + 0]
           << "     " << par_align[i*NPARPLAN + 1] << " +/- " << err_align[i*NPARPLAN + 1]
           << "     " << par_align[i*NPARPLAN + 2] << " +/- " << err_align[i*NPARPLAN + 2]
	   << "     " << evalHist[i]->GetMean() << " +/- " << evalHist[i]->GetRMS() << endl;   

      fout1 << "     " << par_align[i*NPARPLAN + 0] + p_geomSvc->getPlaneZOffset(i+1) 
            << "     " << par_align[i*NPARPLAN + 1] + p_geomSvc->getPlanePhiOffset(i+1)
            << "     " << par_align[i*NPARPLAN + 2] + p_geomSvc->getPlaneWOffset(i+1) 
	    << "     " << evalHist[i]->GetRMS() << endl;   

      fout2 << "     " << par_align[i*NPARPLAN + 0] 
            << "     " << par_align[i*NPARPLAN + 1] 
            << "     " << par_align[i*NPARPLAN + 2] 
	    << "     " << evalHist[i]->GetMean() << endl;   
    }

  fout1.close();
  fout2.close();
}

void SMillepede::printQAPlots()
{
  using namespace MILLEPEDE;

  double dID[MILLEPEDE::NPLAN], edID[MILLEPEDE::NPLAN];
  double dw[MILLEPEDE::NPLAN], edw[MILLEPEDE::NPLAN];
  double dphi[MILLEPEDE::NPLAN], edphi[MILLEPEDE::NPLAN];
  double dz[MILLEPEDE::NPLAN], edz[MILLEPEDE::NPLAN];
  for(int i = 0; i < MILLEPEDE::NPLAN; i++)
    {
      dID[i] = i+1;
      edID[i] = 0.;

      dz[i] = par_align[i*NPARPLAN + 0] + p_geomSvc->getPlaneZOffset(i+1);
      edz[i] = err_align[i*NPARPLAN + 0];

      dphi[i] = par_align[i*NPARPLAN + 1] + p_geomSvc->getPlanePhiOffset(i+1);
      edphi[i] = err_align[i*NPARPLAN + 1];

      dw[i] = par_align[i*NPARPLAN + 2] + p_geomSvc->getPlaneWOffset(i+1);
      edw[i] = err_align[i*NPARPLAN + 2];
    }

  TGraphErrors w_vs_id(MILLEPEDE::NPLAN, dID, dw, edID, edw);
  TGraphErrors phi_vs_id(MILLEPEDE::NPLAN, dID, dphi, edID, edphi);
  TGraphErrors z_vs_id(MILLEPEDE::NPLAN, dID, dz, edID, edz);

  w_vs_id.SetMarkerStyle(8);
  phi_vs_id.SetMarkerStyle(8);
  z_vs_id.SetMarkerStyle(8);

  w_vs_id.SetTitle("");
  phi_vs_id.SetTitle("");
  z_vs_id.SetTitle("");

  w_vs_id.GetXaxis()->SetTitle("detectorID");
  phi_vs_id.GetXaxis()->SetTitle("detectorID");
  z_vs_id.GetXaxis()->SetTitle("detectorID");

  w_vs_id.GetYaxis()->SetTitle("#Deltau (cm)");
  phi_vs_id.GetYaxis()->SetTitle("#Delta#theta (rad)");
  z_vs_id.GetYaxis()->SetTitle("#Deltaz (cm)");

  w_vs_id.GetYaxis()->CenterTitle();
  phi_vs_id.GetYaxis()->CenterTitle();
  z_vs_id.GetYaxis()->CenterTitle();

  TCanvas c;
  c.cd(); c.SetGridx(); c.SetGridy(); w_vs_id.Draw("APL");
  c.SaveAs("w_vs_id.eps");
  c.cd(); c.SetGridx(); c.SetGridy(); phi_vs_id.Draw("APL");
  c.SaveAs("phi_vs_id.eps");
  c.cd(); c.SetGridx(); c.SetGridy(); z_vs_id.Draw("APL");
  c.SaveAs("z_vs_id.eps");

  if(evalTree == NULL) return;

  TCanvas r;
  r.Divide(6, 4);
  for(int i = 0; i < MILLEPEDE::NPLAN; i++)
    {
      r.cd(i+1);
      evalHist[i]->Draw();
    }
  r.SaveAs("residuals.eps");
}

void SMillepede::bookEvaluationTree(std::string evalFileName)
{
  evalFile = new TFile(evalFileName.c_str(), "recreate");
  evalTree = new TTree("save", "save");
  evalNode = new MPNode(0);

  evalTree->Branch("runID", &runID, "runID/I");
  evalTree->Branch("eventID", &eventID, "eventID/I");
  evalTree->Branch("MPnode", &evalNode, 256000, 99);

  //Evaluation histigrams
  for(int i = 0; i < MILLEPEDE::NPLAN; i++)
    {
      evalHist[i] = new TH1D(p_geomSvc->getDetectorName(i+1).c_str(), p_geomSvc->getDetectorName(i+1).c_str(), 100, -0.5*p_geomSvc->getPlaneSpacing(i+1), 0.5*p_geomSvc->getPlaneSpacing(i+1));
    }
}

void SMillepede::fillEvaluationTree()
{
  if(evalFile == NULL) return;

  //runID = rawEvent->getRunID();
  //eventID = rawEvent->getEventID();
  for(std::vector<MPNode>::iterator node = nodes.begin(); node != nodes.end(); ++node)
    {
      if(!node->isValid()) continue;

      *evalNode = *node;
      evalTree->Fill();
      evalHist[node->detectorID-1]->Fill(node->meas);
    }
}
